{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "import numpy as np\n",
    "import math\n",
    "from tensorboardX import SummaryWriter\n",
    "from torchvision import datasets, transforms\n",
    "import torchvision.utils as vutils\n",
    "from tqdm import tqdm, trange"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$y = x + 0.3 sin(2\\pi(x + \\epsilon)) + 0.3sin(4\\pi(x + \\epsilon)) + \\epsilon$$\n",
    "$$ \\epsilon \\sim N(0, 0.02)$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 1000\n",
    "x = np.linspace(0, 0.5, N)[:, np.newaxis]\n",
    "x_test = np.linspace(-0.5, 1.0, 2 * N)[:, np.newaxis]\n",
    "eps = np.random.normal(0, 0.02, x.shape)\n",
    "y = x + 0.3 * np.sin(2 * np.pi * (x + eps))\\\n",
    "    + 0.3 * np.sin(4 * np.pi * (x + eps)) + eps\n",
    "y_test = x_test + 0.3 * np.sin(2 * np.pi * x_test)\\\n",
    "    + 0.3 * np.sin(4 * np.pi * x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f1caf60def0>]"
      ]
     },
     "execution_count": 170,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAD8CAYAAABzTgP2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl4VOXZ+PHvnZAAESphUUAgKUWL2lbFyFLaurAIuGAvW4oGBEFjgm83rEqJVlCpS19eahULkaJgohX4ufBWFJJUiwooAUUJLiwSoFANu4AQkty/PzLhncQsk8xZZpL7c13nysyZ55znzhDmnvNsR1QVY4wxplKM3wEYY4yJLJYYjDHGVGGJwRhjTBWWGIwxxlRhicEYY0wVlhiMMcZUYYnBGGNMFZYYjDHGVGGJwRhjTBUt/A6gMTp27KjJycl+h2GMMVFl3bp1e1W1U33lojIxJCcnU1BQ4HcYxhgTVUSkKJRy1pRkjDGmCksMxhhjqrDEYIwxpgpLDMYYY6qwxGCMMaYKRxKDiMwXkS9FZGMtr4uI/EVEtojIhyLSJ+i1cSKyObCNcyIeY4wxjefUFcMzwLA6Xh8OnB3Y0oC/AohIe+A+oB/QF7hPRBIdiskYY0wjOJIYVHUlsL+OIiOBhVphDdBORLoAVwK5qrpfVQ8AudSdYIwxplnavHkz9957L3v27HG9Lq/6GM4CdgY93xXYV9v+bxCRNBEpEJGC4uJi1wI1xphI9Oqrr/Lggw9SWlrqel1eJQapYZ/Wsf+bO1WzVDVFVVM6dap3RrcxxjQp+fn5nH322XTv3t31urxKDLuA4N+mG7C7jv3GGGMCSktL+de//sUVV1zhSX1eJYalwE2B0Un9gUOqugdYDgwVkcRAp/PQwD5jjDEBBQUFfPXVVwwaNMiT+hxZRE9EngcuAzqKyC4qRhrFAajqHGAZMALYAhwDbg68tl9EHgDWBk51v6rW1YltjDHNTn5+PgCXX365J/U5khhU9YZ6Xlfg9lpemw/MdyIOY4xpivLz87nwwgvp2LGjJ/XZzGdjjIlgX3/9NatWrfKsfwEsMRhjTERbtWoVJ06c8Kx/ASwxGGNMRMvPz6dFixb8+Mc/9qxOSwzGGBPB8vLy6Nu3L23btvWsTksMxhgToQ4cOEBBQQFDhgzxtF5LDMYYE6HeeOMNVNUSgzHGmAq5ubm0adOGvn37elqvJQZjjIlQeXl5XHbZZcTFxXlaryUGY4yJQNu3b2fLli2eNyOBJQZjjIlIeXl5AAwePNjzui0xGGNMBMrLy6Nr166ce+65ntdticEYYyJMeXk5+fn5DB48GJGablvjLksMxhgTYTZs2MDevXt9aUYCSwzGGBNxKvsXvFwfKZglBmOMiTB5eXmcf/75dO3a1Zf6LTEYY0wEOX78OCtXrvStGQksMRhjTERZtWoVx48ft8RgjDGmQl5eHi1atODSSy/1LQZLDMYYE0Fyc3Pp37+/p8tsV2eJwRhjIkRxcTHr1q3zZRmMYI4kBhEZJiKfisgWEZlSw+uzROSDwPaZiBwMeq0s6LWlTsRjjDHRKDc3F1Vl+PDhvsbRItwTiEgsMBsYAuwC1orIUlXdVFlGVX8bVP6XwEVBp/haVS8MNw5jjIl2r732Gh07duTiiy/2NQ4nrhj6AltUdZuqlgB/B0bWUf4G4HkH6jXGmCajvLyc5cuXc+WVVxIT428rvxO1nwXsDHq+K7DvG0QkCfg28M+g3a1EpEBE1ojIdQ7EY4wxUWf9+vUUFxczbNgwv0MJvykJqGmFJ62l7GhgiaqWBe3roaq7RaQn8E8R+UhVt36jEpE0IA2gR48e4cZsjDER5fXXX0dEuPLKK/0OxZErhl1A96Dn3YDdtZQdTbVmJFXdHfi5DXiTqv0PweWyVDVFVVM6deoUbszGGBNRXnvtNS6++GIi4fPNicSwFjhbRL4tIvFU
fPh/Y3SRiHwXSARWB+1LFJGWgccdgYHApurHGmNMU3bgwAHWrFkTEc1I4EBiUNVS4L+A5cDHwCJVLRSR+0Xk2qCiNwB/V9XgZqZzgQIR2QC8ATwcPJrJGNM0fPbZZ0yfPp0f/vCHtGvXjpiYGE477TQuuugiJk+ezNq1a/0O0Ve5ubmUl5f7Pky1klT9nI4OKSkpWlBQ4HcYJgrk5OSQmZlJUVFR2Oc67bTTOHr0KAAdOnTgscceIzU1NezzNmXbtm1jypQpLFmyBIB+/frRp08f2rdvz1dffUVhYSErV66kpKSEn/zkJ8ycOZOUlBSfo/behAkTeOmllyguLqZFCye6fmsmIutUtf43WFWjbrv44ovVmPpkZ2drQkKCUjEYwrUtKSlJs7Oz/f51I0p5ebn++c9/1oSEBG3Tpo1OnTpVd+/eXWPZQ4cO6axZs7Rz584qIjp58mQtKSnxOGL/lJeXa5cuXXTUqFGu1wUUaAifsb5/yDdms8Rg6pKdna1JSUmuJ4TgTUQ0IyPD7189Ihw7dkxvuOEGBfSqq67SnTt3hnTcwYMHNSMjQwEdOHCg7tmzx+VII8P69esV0Pnz57telyUG0yxlZ2drXFycp0khODk09yuHgwcPav/+/VVE9KGHHtLy8vIGn+P555/XhIQE7dWrl27fvt2FKCPLtGnTVET0iy++cL0uSwymWerQoYMvSaFy69ChgyYlJamINLsmpoMHD2q/fv20RYsWumTJkrDOtWrVKj399NO1e/fuum3bNocijEx9+vTRH/7wh57UFWpisNVVTdTKyckhOTmZmJgYkpOTycnJYd++fb7GtG/fPoqKilBVioqKSEtLIycnx9eYvHD8+HFGjBjBunXrWLJkCddff31Y5xswYABvvvkmR44cYdiwYezdu9ehSCPLrl27WL9+Pddee239hb0USvaItM2uGExNHctedDQ3ZouNjW3SVw5lZWX6i1/8QgFdtGiRo+d+6623tFWrVtqvXz89evSoo+eOBE8++aQCumnTJk/qw5qSTFPmdedyuFtCQkKTTQ733HOPAvrII4+4cv4XX3xRRUTHjh3bqD6LSDZs2DDt1auXZ7+XJQbTpImI7x/2Dd2SkpL8ftsct3TpUgV04sSJrn64TZs2TQGdPXu2a3V47fDhwxofH6+TJ0/2rM5QE4P1MZioFI0LKe7YscPvEBy1Y8cOxo0bx0UXXcQTTzyBSE3raTrj3nvvZcSIEfzmN7/h3Xffda0eL61YsYKSkpLI61/Abu1potSMGTNISEjwO4wGicZkVpuTJ08yevRoSktLWbRoEa1atXK1vpiYGLKzs+nSpQtjxozhyJEjrtbnhaVLl5KYmMjAgQP9DuUbLDGYqBE8CikzM5Nx48b5HVLIEhISmDFjht9hOOaBBx5g9erVzJs3j169enlSZ2JiIgsXLmTr1q3ccccdntTplpMnT/KPf/yDq666ytUlMBotlPamSNusj6H5qWkUkl8T2RqzRXLHc+VM8VDnXhQUFGhsbKzedNNNHkVY1Z133qmAvvLKK77U74TXX3/dl98B63w2TYlXo5Dc6NSO5E7nmhJu5XtQU5I4fvy4nn/++dq1a1fdv3+/LzEfP35cL7jgAj3zzDN9iyFcEydO1LZt2+rXX3/tab2hJgZrSjJRwYuO24SEBNLT00lKSnL0vCNGjKhxMl5N+7yWmZnJsWPHquyr+PyAoqIixo4dy6RJk069Nn36dAoLC3nqqadITEz0NNZKLVu2ZP78+ezdu5c777zTlxjCcfLkSV566SWuvfZa1/tmGi2U7BFpm10xND9uLXVRVxNKTEyMa1cmcXFxGh8fX2WfH3MdQrlCqlwDqqCgQGNiYvTmm2/2NMba3HXXXQpofn6+36E0SGUz0ssvv+x53VhTkmlK3EgMHTp0qLPOypU+vdzcbnaq3p8Q6vvao0cPvfjii7Vz58564MABV2MM1dGjR/U73/mOfuc739Fjx475HU7I
/GpGUrXEYJoYN9r+Q/l23qZNG08Tg4i49h7W1oFf/cqlru25555zLb7GyM/PV0Dvvvtuv0MJSUlJibZv315vvPFGX+oPNTFYH4OJCk7PARARxo4dW2/b/pw5c4iLi3O07rq4Odehpv6EkydPhvT7iQiDBg1i9OjRboXXKFdccQU333wzM2fOpLCw0O9w6pWfn8/+/fv5+c9/7ncodQsle0TaZlcMzY+bd2Orr20/Ozvbk+W83e5jCOeqq0WLFvrpp5+6Fls4iouLNTExUS+99NKIX0vpxhtv1Hbt2unx48d9qR9rSjJNTXZ2tsbGxrryoRxq235GRobjzVpe3bshnCG/99xzj6uxhWvOnDkKaE5Ojt+h1Orw4cPaunVrTU9P9y0GTxMDMAz4FNgCTKnh9fFAMfBBYLsl6LVxwObANi6U+iwxNH01TbrKzs527dt6Q9v2s7OzG9Q2X19S8mI0UjhXXZHeuVtaWqqXXHKJdu7cWQ8ePOh3ODV6+umnFdBVq1b5FoNniQGIBbYCPYF4YANwXrUy44Enaji2PbAt8DMx8DixvjotMTRtTnSS1rbVdsUR6hVDcMJy8srBq6GqwfE3JL5Inrldae3atSoi+pvf/MbvUGp0+eWXe7rEdk28TAwDgOVBz38P/L5amdoSww3A3KDnc4Eb6qvTEkPT5tYs57i4OM3IyKjxBj+hfPC52c/RkOTkhIYOxY3k2dvB0tPTNTY2Vjds2OB3KFUUFRWpiOj06dN9jSPUxODEqKSzgJ1Bz3cF9lV3vYh8KCJLRKR7A481zYhbs5xFhIEDB5KVlUVSUhIiQlJSEllZWaSmptZ7fE2jepzk1bLcOTk5/PWvf23QMUVFRS5F46wZM2aQmJjIpEmTKC8v9zucUxYsWICqMmbMGL9DCYkTiaGmRdi12vP/BZJV9QdAHrCgAcdWFBRJE5ECESkoLi5udLAm8rk1ZLOkpITMzExSU1PZvn075eXlbN++PaSkAO5/cHu1LPevf/3rBh8TGxvrQiTOa9++PY888gjvvPMOzz77rN/hAFBaWkpWVhZDhgyhZ8+efocTEicSwy6ge9DzbsDu4AKquk9VTwSePgVcHOqxQefIUtUUVU3p1KmTA2GbSDVjxgzXbvoSzoe7mx/c8fHxri/LXbk20759+xp8bFlZmQsRuWP8+PEMGDCAO++8kwMHDvgdDsuWLWPXrl1kZGT4HUrInEgMa4GzReTbIhIPjAaWBhcQkS5BT68FPg48Xg4MFZFEEUkEhgb2mWYsNTW1ss/JceF8uNd0c6CYGGfmiJ48edKR89QmJyeHtLS0RjcJRcsVA1T8mzz55JPs27ePe+65x+9wmDNnDl27duXqq6/2O5TQhdIRUd8GjAA+o2J0UmZg3/3AtYHHDwGFVIxYegPoHXTsBCqGuW4Bbg6lPut8bvrc6oDOyMgIK66ahtE6NfmtvrWbwuHE+xltfvWrX6mI6OrVq32LYdu2bSoi+oc//MG3GIJhE9xMNHNrzoIbw0KdjDU7O1szMjJODauNjY0NO5mpathxRcuopGCHDx/Wbt266fe//30tKSnxJYbJkydrbGys7ty505f6q7PEYKKeG4nBrQ+5QYMGuRYvhH+lE86M8RYtWkTFPIaavPzyywroww8/7HndBw4c0DZt2mhqaqrnddfGEoOJem4tf+HWCqZuLtMdGxsbVmzh1O1mE5cXfvrTn2rr1q1169atntb70EMPKaAffPCBp/XWJdTEYKurmoiVlpbmynndGl00cOBAV84LDR8VFHx3uI4dO4ZVd2NGMUWSxx9/nBYtWpCRkVHxbdgDJ06c4LHHHmPIkCFccMEFntTpqFCyR6RtdsXQfDh9PwQ3l55w+77Uocbt9Aztyju4RbPHH39cAV24cKEn9T3xxBMKaG5urif1hQprSjJNgZMdu7Gxsa5+wLmZFCD0vhE3ElQ0dj4HKy0t1YEDB+rpp5+uO3bscLWuo0ePaufOnSNyGfBQ
E4M1JZmIlZOTQ2ZmpiPnSkhIYMGCBSHPcm4Mt8f6hzI5Lycnx5XlK7xarsMtsbGxLFy4kLKyMsaPH+/qchlPPPEE//nPf1ydqOm6ULJHpG12xdD0Odkc4vaVQiUnYq1rq+9bu5uL/EX7FUOlp556SgGdNWuWK+f/4osvtF27djp8+HBXzh8urCnJRDMnm0PcvI+yWzFX3+Lj4+tNbm7WH+19DJXKy8v1mmuu0ZYtW2pBQYHj57/55ps1Li5ON23a5Pi5nWCJwUQ1J+914NW33ezsbI2Li3MtMdTH6TvLeZ1YvVJcXKw9evTQpKQk3bt3r2PnfeeddxTQu+++27FzOi3UxGB9DCYitW/f3pHzeLE4XaXU1FSefvppOnTo4Pi5S0pKyMnJqbOMW8NwKz5Pmo6OHTuyZMkS9uzZw4033khpaWnY5zx69Cg333wz3bt3j4j1mcJlicFEnJycHA4fPuzIudq2betqh3N1qamp7N279/8uyR00ZswYkpOTa00QI0aMcLS+SklJSa6c10+XXHIJTz75JCtWrGDSpElh/1tNnjyZzZs3s3DhQtq0aeNQlD4K5bIi0jZrSmraorF/oTZuzN6uaS6GW7OuvbrlqF8yMzMV0MzMzEafY968eQroXXfd5WBk7sD6GEw0cnrxPL9H07jxYV35ezm5smtddTRl5eXleuuttyqgU6ZMafC8g+XLl2uLFi106NChevLkSZeidI4lBhN1nB5uGQnfdt0cKeTWWlLRvjZSQ5WWlmp6eroCOnbsWD1y5EhIx73yyivasmVL/cEPfqAHDx50OUpnWGIwUcfJD9FI+bbr5twCtza/m9/8UF5ervfff7+KiJ577rn6r3/9q9ayX3/9tU6dOlVFRC+55BLdt2+fh5GGxxKDiTpODbcMZcy/l9y6t4Rbm9/Nb37Ky8vTbt26KaBDhgzRZ599Vjdv3qxffPGFrl+/Xh966CHt0aOHAjpx4kQ9evSo3yE3iCUGE3WcvmKIJG4vsOfUFgnNb347evSo/vGPf9Tu3bvX+B796Ec/0ry8PL/DbJRQE4NUlI0uKSkpWlBQ4HcYxmGV9yU+duxY2OcSEVfXw2koJ383tyQlJTFjxgxPh/dGsvLyctavX89HH33EsWPHOOOMM7jkkktITk72O7RGE5F1qppSX7kWXgRjTChSU1N55513+Otf/xr2udya7NVYlR+2mZmZrixyF66kpCS2b9/udxgRJSYmhpSUFFJS6v0cbXJsgpuJKIsWLQr7HAkJCZ7Ndm6I1NRUtm/fTnZ2NgkJCX6Hc0qkvl/GP5YYTEQJ925hSUlJZGVlRXRzSGpqKllZWSQlJfm+LHM0vF/Ge44kBhEZJiKfisgWEZlSw+uTRWSTiHwoIvkikhT0WpmIfBDYljoRj2m+tm/fHhUfcpVXD+Xl5WRnZxMT4+13tISEBLKzs6Pm/TLeCvuvUURigdnAcOA84AYROa9asfeBFFX9AbAEeDTota9V9cLAdm248Zjmy43F67yQmprKwoULXY+/TZs2iIhdJZh6OfE1pS+wRVW3qWoJ8HdgZHABVX1DVSuHY6wBujlQr2mCwvlwHDVqlIOReKty8b3s7OxTi9aFe0e4yquQpKQksrOz+eqrrygvL7erBFMvJxLDWcDOoOe7AvtqMxF4Leh5KxEpEJE1InJdbQeJSFqgXEFxcXF4EZuIkpOTQ3JyMjExMRw/frzR51m2bJmDUfmjsolJVSktLQ1rvk9ZWRmqaonANJgTw1Vr6j2rcXKEiIwBUoBLg3b3UNXdItIT+KeIfKSqW79xQtUsIAsq5jGEH7aJBNXH9x89erTR54r2+xIbEymcuGLYBXQPet4N2F29kIgMBjKBa1X1ROV+Vd0d+LkNeBO4yIGYTJTIzMx0bNJXTEwMMTExdd6zwBhTPycSw1rgbBH5tojEA6OBKqOLROQiYC4VSeHLoP2JItIy8LgjMBDY5EBMJko4+S2/sumkqKiItLQ0Sw7G
NFLYiUFVS4H/ApYDHwOLVLVQRO4XkcpRRn8C2gCLqw1LPRcoEJENwBvAw6pqiaEZcWKGck1DPY8dO0ZmZmbY5zamObK1koyvcnJymDBhAiUlJQ0+NjY2lgULFjB27Fhq+juOtPWSjPFbqGsl2cxn47vGfDkRERYsWEBqamqtVx2Rtl6SMdHCEoPxVWZmJidPnmzwcVdcccWpIZgzZsz4xtpDtv6PMY1nicH4qrGdz6tXrz7VuVx97SGb2WtMeKyPwfgqOTm50ctQ21LRxjSM9TGYqDBixIhGH1tUVGTzFoxxgSUG46twl7GweQvGOM8Sg/GVU3czs3kLxjjHEoPxVbgriAaztZKMcYYlBuOrsrKyBpVv1apVrUtz27wFY5zhxOqqxjRKY/oEjh8/TllZGfHx8VVmS9u8BWOcY1cMxje33XZbo447efIkbdu2tXkLxrjErhiMb8K598L+/fvZu3evg9EYYyrZFYOJStafYIx7LDGYqCMi1p9gjIssMRjfnHbaaY06TlWtP8EYF1liML6ZO3duo45LSkpyOBJjTDBLDMY377zzTqOOs2YkY9xlicH4pjFXDBkZGdaMZIzLLDEY3zT0tpuDBg3iySefdCkaY0wlSwwmKmRkZJCXl+d3GMY0C44kBhEZJiKfisgWEZlSw+stReSFwOvvikhy0Gu/D+z/VESudCIeEx1CHZWUkZFhVwrGeCjsxCAiscBsYDhwHnCDiJxXrdhE4ICq9gJmAY8Ejj0PGA2cDwwDngyczzQDc+fODWl11XDv2WCMaRgnrhj6AltUdZuqlgB/B0ZWKzMSWBB4vAQYJCIS2P93VT2hqp8DWwLnM81Aamoq8+fPr7ecLadtjLecSAxnATuDnu8K7KuxjKqWAoeADiEeC4CIpIlIgYgUFBcXOxC2iQSff/55vWVs+QtjvOVEYpAa9mmIZUI5tmKnapaqpqhqSqdOnRoYoolUs2bNqvN1W/7CGO85kRh2Ad2DnncDdtdWRkRaAKcD+0M81jRhhw4dqvN1VWXs2LEkJyfbPZ2N8YgTiWEtcLaIfFtE4qnoTF5arcxSYFzg8c+Af6qqBvaPDoxa+jZwNvCeAzGZKHDkyJGQyqkqRUVFpKWlWXIwxgNhJ4ZAn8F/AcuBj4FFqlooIveLyLWBYn8DOojIFmAyMCVwbCGwCNgEvA7crqoNu9ejiVpvvfVWg8ofO3aMzMxMl6IxxlRy5EY9qroMWFZt3x+CHh8Hfl7LsTMAa0RuhhozYc1GKBnjPpv5bHzTmMRgI5SMcZ8lBuOLL774gg8//JB27dqFfExCQoKNUDLGA5YYjC/y8/MBuOOOO0hISKizrIiQlJREVlaWraxqjAcc6WMwpqHy8vJo3749SUlJtG7dmmPHjtVatqGrsBpjwmOJwXhOVcnLy6NXr16kp6fXmRTsbm3GeM8Sg/Hc5s2b2blzJ0ePHq0zKVifgjH+sD4G47nc3FwA9u/fX2sZ61Mwxj+WGIzn8vLy6NSpExUL7H5T5fpIlhSM8YclBuOp0tJS3njjDb7++msqVkX5JlW1Gc7G+Mj6GIyn1q5dW+/CeWAznI3xk10xGE/l5uYiInTr1q3OcjbD2Rj/WGIwnsrNzSUlJYWHH3641oltNhrJGH9ZYjCeOXz4MKtXr2bIkCGkpqaSlZV1ap5CcEd069at/QrRGIMlBuOhN998k7KyMoYMGQJU3PN5xowZdOjQoUpH9L59+xg7diyTJk3yK1RjmjVLDMYzubm5nHbaaQwYMACAnJwc0tLS2Ldv3zfKqipz5syxG/MY4wNLDMYzK1as4NJLL6Vly5YAZGZm1jnzWVX59a9/7VV4xpgASwzGEzt27OCzzz471YxUua8++/bts6sGYzxmicF4onIZjKFDh57aF+qQVJvsZoy3LDEYT6xYsYKuXbty7rnnnto3Y8aMeu/FADbZzRivWWIwrisvLyc/P58hQ4ZUGZYaPGRVRIiJqfnP0Sa7GeOtsBKD
iLQXkVwR2Rz4mVhDmQtFZLWIFIrIhyLyi6DXnhGRz0Xkg8B2YTjxmMj0/vvvs2/fvirNSJVSU1PZvn075eXlLFy48BtXEDbZzRjvhXvFMAXIV9WzgfzA8+qOATep6vnAMODPIhJ8o987VfXCwPZBmPGYCLRixQoABg8eXGe56lcQtvS2Mf4IdxG9kcBlgccLgDeBu4MLqOpnQY93i8iXQCfgYJh1myixYsUKLrjgAs4444x6y6amploiMMZn4V4xnKmqewACP+v8ny8ifYF4YGvQ7hmBJqZZItIyzHhMhDl8+DBvv/02w4cP9zsUY0yI6r1iEJE8oHMNLzVoDKGIdAGeBcapauXd3X8P/IeKZJFFxdXG/bUcnwakgXVGRpP8/HxKS0stMRgTRepNDKpaa8OwiHwhIl1UdU/gg//LWsp9C3gVuEdV1wSde0/g4QkReRr4XR1xZFGRPEhJSan5Di8m4ixbtozTTz/91DIYxpjIF25T0lJgXODxOOCV6gVEJB54CVioqourvdYl8FOA64CNYcZjIoiqsmzZMoYOHUpcXJzf4RhjQhRu5/PDwCIRmQjsAH4OICIpQLqq3gKMAn4CdBCR8YHjxgdGIOWISCdAgA+A9DDj8VRZWRlvv/02r732GoWFhezfv5+WLVvSq1cvfvzjH3P11VeTmPiNEbzNxocffsju3bsZMWKE36EYYxpAarvvbiRLSUnRgoIC3+ovLS3lb3/7G48++ijbtm0jPj6e3r1706lTJ44dO8ann37K/v37ad26NRMnTiQzM5POnWvqpmnaHnroIaZOncqePXua5e9vTKQRkXWqmlJfOZv53EAbNmygf//+pKen07FjR55//nn27t3Lhg0byMvLY9WqVRQXF7NmzRpGjx7N3Llz6d27N0899RTRmITDsWzZMvr06WNJwZgoY4mhAXJycujXrx87d+7khRdeOPXh37Zt2yrlYmJi6NevH/Pnz2fjxo306dOHtLQ0xo0bV+cy003JgQMHWLVqlTUjGROFLDGEQFW57777GDNmDP3792fjxo2MGjWqyro/tTnnnHPIy8tj+vTpZGdnM3jwYA4ebPpz+1asWEF5ebklBmOikCWGeqgqU6dO5f7772fChAnk5ubSqVMuhhPbAAAQBElEQVSnBp0jJiaGP/zhDyxevJiCggKuuOIK9u7d61LEkWHZsmW0b9+evn37+h2KMaaBLDHU47777uPhhx8mPT2dp556Kqxhl9dffz2vvPIKmzZt4uqrr26yzUplZWUsW7aMK6+8ktjYWL/DMcY0kCWGOsyfP58HHniACRMmMHv27FqXhW6I4cOH8/zzz/Pee+8xevRoSktLHYg0sqxatYq9e/dy3XXX+R2KMaYRLDHUIj8/n9tuu40hQ4YwZ84cR5JCpZ/+9Kc8/vjj/O///i933XWXY+eNFC+//DLx8fEMGzbM71CMMY0Q7gS3Jmnnzp2MGjWK7373uyxevNiVWbu33347n3zyCbNmzaJ///6MGjXK8Tr8oKq88sorDBo0iG9961t+h2OMaQS7Yqjm5MmTjB49mpKSEl588UVOP/101+qaOXMmAwYMYMKECXz88ceu1eOlwsJCtm7das1IxkQxSwzVZGZmsmrVKubNm8c555zjal3x8fEsXryYhIQEbrjhBk6cOOFqfV54+eWXAbjmmmt8jsQY01iWGILk5eXxpz/9iYyMDH7xi1/Uf4ADzjrrLP72t7+xYcMG7rvvPk/qdNPLL79M//796dKli9+hGGMayRJDwOHDh5kwYQK9e/dm5syZntZ9zTXXcOutt/Loo4+ycuVKT+t20s6dO1m3bp01IxkT5SwxBNxxxx38+9//5plnnqF169ae1/8///M/9OzZk5tuuokjR454Xr8TKpuRRo4c6XMkxphwWGIAXn/9debNm8ddd91Fv379fImhTZs2LFiwgKKiIu69915fYgjXCy+8wPe+9z169+7tdyjGmDA0+8Rw6NAhbrnl
Fs4//3ymTZvmaywDBw4kIyODxx57jPfee8/XWBpq586dvPPOO4wePdrvUIwxYWr2iSEzM5M9e/bw9NNP07JlS7/D4aGHHqJr167ccsstnDx50u9wQrZ4ccXN+bzqtDfGuKdZJ4aCggKefPJJbr/9di655BK/wwHg9NNPZ/bs2Xz00Uf893//t9/hhOyFF16gT58+9OrVy+9QjDFharaJoaysjPT0dM4880weeOABv8OpYuTIkVx//fVMnz6dzZs3+x1OvT7//HPee+89u1owpolotolhzpw5rFu3jlmzZrk6u7mx/vKXv9CyZUsmTZoU8Xd+W7RoEUCTWdbDmOauWSaG//znP0ydOpXBgwdH7Lfcrl278sc//pG8vDyee+45v8OplaqSnZ1N//79SU5O9jscY4wDwkoMItJeRHJFZHPgZ2It5cpE5IPAtjRo/7dF5N3A8S+ISHw48YTqjjvu4Pjx48yePTuku7D5JT09nb59+zJ58mT279/vdzg1Wr9+PRs3bmT8+PF+h2KMcUi4VwxTgHxVPRvIDzyvydeqemFguzZo/yPArMDxB4CJYcZTr/z8fJ577jmmTJni+lpI4YqNjWXu3Lns27ePKVNqe2v99cwzz9CyZcuIvfIyxjSchNN+LSKfApep6h4R6QK8qarfraHcEVVtU22fAMVAZ1UtFZEBwDRVvbK+elNSUrSgoKDB8Z44cYIf/OAHlJWVsXHjRlq1atXgc/jhd7/7HTNnzuTtt99m4MCBfodzyokTJ+jatStDhw7l+eef9zscY0w9RGSdqqbUVy7cK4YzVXUPQODnGbWUayUiBSKyRkQqF9LpABxU1cpbmO0Czgoznjo9+uijfPbZZ8yePTtqkgLAtGnT6NGjB7fddhslJSV+h3PKP/7xD/bv32/NSMY0MfUmBhHJE5GNNWwNWRCnRyBL3Qj8WUS+A9TUuF/r5YuIpAWSS0FxcXEDqv4/u3fvZtSoUVx5Zb0XJRGlTZs2PPHEExQWFnq+wF9dnn76abp27crgwYP9DsUY4yBPmpKqHfMM8A/g/+FxUxJAaWkpLVpE543rrr/+epYtW0ZhYSE9e/b0NZbt27fTs2dPpk6dyoMPPuhrLMaY0HjVlLQUGBd4PA54pYZAEkWkZeBxR2AgsEkrMtIbwM/qOt5p0ZoUoGJuQ1xcXETMbai8D/Ztt93maxzGGOeFmxgeBoaIyGZgSOA5IpIiIvMCZc4FCkRkAxWJ4GFV3RR47W5gsohsoaLP4W9hxtOknXXWWTz44IMsX7781KQyPxw/fpx58+YxcuRIunfv7lscxhh3hNWU5JdwmpKiXVlZGf369WPXrl188skntGvXzvMYFixYwPjx48nPz+eKK67wvH5jTON41ZRkPBYbG0tWVhbFxcX8/ve/97x+VeUvf/kL5557Lpdffrnn9Rtj3GeJIQr16dOHX/3qV8ydO5fVq1d7Wndubi7r169n8uTJET1r3BjTeNaUFKW++uorzjvvPBITE1m3bh1xcXGe1Hv55ZezefNmtm7dGhH3rzDGhM6akpq4tm3b8vjjj/PRRx/x5z//2ZM6V61axZtvvskdd9xhScGYJsyuGKLcddddx4oVK9i0aZPrq5sOGzaMtWvXUlRURJs2beo/wBgTUeyKoZl4/PHHiYmJ4fbbb3d1bkN+fj7Lly9n6tSplhSMaeIsMUS57t2788ADD7Bs2TKeffZZV+ooLy/nrrvuokePHtx+++2u1GGMiRyWGJqAX/7yl/zkJz8hIyODjz/+2PHzP/vss6xfv54HH3wwqhYfNMY0jvUxNBG7d+/mwgsv5Mwzz+Tdd98lISHBkfPu3buX3r17893vfpe33nqLmBj7LmFMtLI+hmama9euZGdnU1hYSFpammP9Db/97W85dOgQc+fOtaRgTDNh/9ObkKFDh/LAAw+Qk5PD9OnTwz5fTk4O2dnZTJ06le9973sORGiMiQbRu9SoqdHUqVPZunUr06dPp1u3
btxyyy2NOs8nn3xCeno6P/rRj7j33nsdjtIYE8ksMTQxIsLcuXPZs2cPt956K2VlZQ1eGnvPnj0MHz6chIQEnnvuuaheqtwY03DWlNQExcXF8dJLL3HVVVeRnp7O9OnTKS8vD+nYnTt3MmjQIIqLi3n11VdtWW1jmiFLDE1Uq1atePHFF7npppuYNm0aV199NTt27KjzmJUrVzJgwAD+/e9/8+qrr5KSUu/gBWNME2SJoQmLj4/nmWeeYfbs2bzxxhv07t2b3/3udxQWFp4ataSqvPfee4wZM4bLLruM1q1bs3LlSi699FKfozfG+MXmMTQTRUVFTJkyhcWLF1NWVka7du1o3749X375JUeOHKF169b88pe/5N5777UlL4xpokKdx2CJoZnZvXs3r776Ku+//z6HDh2iY8eOpKSkcPXVV5OYmOh3eMYYF4WaGGy4STPTtWtXbr31Vr/DMMZEMOtjMMYYU4UlBmOMMVWElRhEpL2I5IrI5sDPbzRSi8jlIvJB0HZcRK4LvPaMiHwe9NqF4cRjjDEmfOFeMUwB8lX1bCA/8LwKVX1DVS9U1QuBK4BjwIqgIndWvq6qH4QZjzHGmDCFmxhGAgsCjxcA19VT/mfAa6p6LMx6jTHGuCTcxHCmqu4BCPw8o57yo4Hnq+2bISIfisgsEan1DvMikiYiBSJSUFxcHF7UxhhjalVvYhCRPBHZWMM2siEViUgX4PvA8qDdvwd6A5cA7YG7azteVbNUNUVVUzp16tSQqo0xxjRAvfMYVHVwba+JyBci0kVV9wQ++L+s41SjgJdU9WTQufcEHp4QkaeB34UYtzHGGJeEO8FtKTAOeDjw85U6yt5AxRXCKUFJRajon9gYSqXr1q3bKyJFjQvZcR2BvX4HUY9IjzHS4wOL0QmRHh9EfozhxpcUSqGwlsQQkQ7AIqAHsAP4uaruF5EUIF1VbwmUSwbeAbqrannQ8f8EOgECfBA45kijA/KBiBSEMsXcT5EeY6THBxajEyI9Poj8GL2KL6wrBlXdBwyqYX8BcEvQ8+3AWTWUuyKc+o0xxjjPZj4bY4ypwhJD+LL8DiAEkR5jpMcHFqMTIj0+iPwYPYkvKpfdNsYY4x67YjDGGFOFJYYGCmXhwKCy3xKRf4vIE5EWo4hcKCKrRaQwMPP8Fx7ENUxEPhWRLSLyjXW1RKSliLwQeP3dwGg2T4UQ42QR2RR4z/JFJKThf17FF1TuZyKigRGCngolRhEZFXgfC0XkuUiLUUR6iMgbIvJ+4N96hMfxzReRL0WkxiH8UuEvgfg/FJE+jgagqrY1YAMeBaYEHk8BHqmj7GPAc8ATkRYjcA5wduBxV2AP0M7FmGKBrUBPIB7YAJxXrcwkYE7g8WjgBY/ft1BivBxICDzO8DLGUOILlGsLrATWACkR+B6eDbwPJAaenxGBMWYBGYHH5wHbPY7xJ0AfYGMtr48AXqNiqH9/4F0n67crhoYLaeFAEbkYOJOqK8l6pd4YVfUzVd0ceLybilnrbq410hfYoqrbVLUE+HsgzmDBcS8BBgUmP3ql3hi1YrXgykUg1wDdIim+gAeo+HJw3MPYKoUS463AbFU9AKCqda2Y4FeMCnwr8Ph0YLeH8aGqK4H9dRQZCSzUCmuAdoHVJxxhiaHh6l04UERigJnAnR7HVqlBixuKSF8qvjltdTGms4CdQc938c25LafKqGopcAjo4GJM1YUSY7CJVHxr80q98YnIRVRMJP2Hh3EFC+U9PAc4R0TeEZE1IjLMs+gqhBLjNGCMiOwClgG/9Ca0kDX0b7VB7J7PNRCRPKBzDS9lhniKScAyVd3p1hdeB2KsPE8X4FlgnAbNSndBTW9E9SFxoZRxU8j1i8gYIAW41NWIqlVbw75T8QW+kMwCxnsVUA1CeQ9bUNGcdBkVV1xvicj3VPWgy7FVCiXGG4BnVHWmiAwAng3E6Ob/kYZw9f+KJYYaaPgLBw4Afiwik4A2QLyIHFHVWjsL
fYgREfkW8CpwT+By1E27gO5Bz7vxzcvzyjK7RKQFFZfwdV1OOy2UGBGRwVQk4EtV9YRHsUH98bUFvge8GfhC0hlYKiLXasVqBJEQY2WZNVqxoObnIvIpFYlirTchhhTjRGAYgKquFpFWVKxT5HWzV21C+lttLGtKarjKhQOhloUDVTVVVXuoajIVK8YudDIphKDeGEUkHngpENtiD2JaC5wtIt8O1D06EGew4Lh/BvxTAz1tHqk3xkBTzVzgWh/axuuMT1UPqWpHVU0O/O2tCcTpVVKoN8aAl6noxEdEOlLRtLQtwmLcQWC5HxE5F2gFRNKNYJYCNwVGJ/UHDun/rVYdPi972pvCRkWbdz6wOfCzfWB/CjCvhvLj8X5UUr0xAmOAk1QsXli5XehyXCOAz6joy8gM7Lufig8vqPjPtxjYArwH9PTh37e+GPOAL4Les6WRFF+1sm/i8aikEN9DAf4H2AR8BIyOwBjPo2Lhzw2Bf+ehHsf3PBUjBU9ScXUwEUinYqHRyvdwdiD+j5z+d7aZz8YYY6qwpiRjjDFVWGIwxhhThSUGY4wxVVhiMMYYU4UlBmOMMVVYYjDGGFOFJQZjjDFVWGIwxhhTxf8HbkBgB/Nh+jkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f1cafec77b8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "plt.scatter(x, y, color='k')\n",
    "plt.plot(x_test, y_test, color='k')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 1)"
      ]
     },
     "execution_count": 171,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Gaussian(object):\n",
    "    def __init__(self, mu, rho):\n",
    "        self.mu = mu\n",
    "        self.rho = rho\n",
    "        self.normal = torch.distributions.Normal(0, 1)\n",
    "        \n",
    "    def sample(self):\n",
    "        epsilon = self.normal.sample(self.rho.size())\n",
    "        return self.mu + self.sigma * epsilon\n",
    "    \n",
    "    @property\n",
    "    def sigma(self):\n",
    "        return torch.log1p(torch.exp(self.rho))\n",
    "    \n",
    "    def log_prob(self, input):\n",
    "        return (-0.5 * math.log(2 * math.pi) - torch.log(self.sigma)\\\n",
    "            - ((input - self.mu) ** 2) / (2 * self.sigma ** 2)).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ScaleMixtureGaussian(object):\n",
    "    def __init__(self, pi, sigma1, sigma2):\n",
    "        self.pi = pi\n",
    "        self.sigma1 = sigma1\n",
    "        self.sigma2 = sigma2\n",
    "        self.gaussian1 = torch.distributions.Normal(0, sigma1)\n",
    "        self.gaussian2 = torch.distributions.Normal(0, sigma2)\n",
    "        \n",
    "    def log_prob(self, input):\n",
    "        prob1 = torch.exp(self.gaussian1.log_prob(input))\n",
    "        prob2 = torch.exp(self.gaussian2.log_prob(input))\n",
    "        return (torch.log(self.pi * prob1 + (1 - self.pi) * prob2)).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "metadata": {},
   "outputs": [],
   "source": [
    "PI = 0.5\n",
    "SIGMA_1 = torch.FloatTensor([math.exp(-0)])\n",
    "SIGMA_2 = torch.FloatTensor([math.exp(-6)])\n",
    "\n",
    "class BayesianLinear(nn.Module):\n",
    "    def __init__(self, in_features, out_features):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        # Weight paramters\n",
    "        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))\n",
    "        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-5, -4))\n",
    "        self.weight = Gaussian(self.weight_mu, self.weight_rho)\n",
    "        # Bias parameters\n",
    "        self.bias_mu = nn.Parameter(torch.Tensor(out_features).uniform_(-0.2, 0.2))\n",
    "        self.bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-5, -4))\n",
    "        self.bias = Gaussian(self.bias_mu, self.bias_rho)\n",
    "        # Prior distributions\n",
    "        self.weight_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2)\n",
    "        self.bias_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2)\n",
    "        self.log_prior = 0\n",
    "        self.log_variational_posterior = 0\n",
    "        \n",
    "    def forward(self, input, sample=False, calculate_log_probs=False):\n",
    "        if self.training or sample:\n",
    "            weight = self.weight.sample()\n",
    "            bias = self.bias.sample()\n",
    "        else:\n",
    "            weight = self.weight.mu\n",
    "            bias = self.bias.mu\n",
    "        if self.training or calculate_log_probs:\n",
    "            self.log_prior = self.weight_prior.log_prob(weight)\\\n",
    "                + self.bias_prior.log_prob(bias)\n",
    "            self.log_variational_posterior = self.weight.log_prob(weight)\\\n",
    "                + self.bias.log_prob(bias)\n",
    "        else:\n",
    "            self.log_prior = 0\n",
    "            self.log_variational_posterior = 0\n",
    "        return F.linear(input, weight, bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLES = 10\n",
    "BATCH_SIZE = 32\n",
    "OUTPUT_DIM = 1\n",
    "\n",
    "class BayesianNetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.l1 = BayesianLinear(1, 20)\n",
    "        self.l2 = BayesianLinear(20, 20)\n",
    "        self.l3 = BayesianLinear(10, 1)\n",
    "        \n",
    "    def forward(self, x, sample=False):\n",
    "        x = F.relu(self.l1(x, sample))\n",
    "        x = F.relu(self.l2(x, sample))\n",
    "        x = self.l3(x, sample)\n",
    "        return x\n",
    "    \n",
    "    def log_prior(self):\n",
    "        return self.l1.log_prior\\\n",
    "            + self.l2.log_prior\\\n",
    "            + self.l3.log_prior\n",
    "    \n",
    "    def log_variational_posterior(self):\n",
    "        return self.l1.log_variational_posterior\\\n",
    "            + self.l2.log_variational_posterior\\\n",
    "            + self.l3.log_variational_posterior\n",
    "            \n",
    "    def sample_elbo(self, input, target,\n",
    "                    batch_size=BATCH_SIZE, \n",
    "                    output_dim=OUTPUT_DIM,\n",
    "                    samples=SAMPLES):\n",
    "        outputs = torch.zeros(samples, batch_size, output_dim)\n",
    "        log_priors = torch.zeros(samples)\n",
    "        log_variational_posteriors = torch.zeros(samples)\n",
    "        for i in range(samples):\n",
    "            outputs[i] = self(input, sample=True)\n",
    "            log_priors[i] = self.log_prior()\n",
    "            log_variational_posteriors[i] = self.log_variational_posterior()\n",
    "        log_prior = log_priors.mean()\n",
    "        log_variational_posterior = log_variational_posteriors.mean()\n",
    "        mse_loss = F.mse_loss(outputs.mean(0), target, size_average=False)\n",
    "        loss = (log_variational_posterior - log_prior) / NUM_BATCHES\\\n",
    "            + mse_loss\n",
    "        return loss, log_prior, log_variational_posterior, mse_loss\n",
    "    \n",
    "net = BayesianNetwork()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [],
   "source": [
    "writer = SummaryWriter()\n",
    "\n",
    "def write_weight_histograms(epoch):\n",
    "    writer.add_histogram('histogram/w1_mu', net.l1.weight_mu,epoch)\n",
    "    writer.add_histogram('histogram/w1_rho', net.l1.weight_rho,epoch)\n",
    "    writer.add_histogram('histogram/w2_mu', net.l2.weight_mu,epoch)\n",
    "    writer.add_histogram('histogram/w2_rho', net.l2.weight_rho,epoch)\n",
    "    writer.add_histogram('histogram/w3_mu', net.l3.weight_mu,epoch)\n",
    "    writer.add_histogram('histogram/w3_rho', net.l3.weight_rho,epoch)\n",
    "    writer.add_histogram('histogram/b1_mu', net.l1.bias_mu,epoch)\n",
    "    writer.add_histogram('histogram/b1_rho', net.l1.bias_rho,epoch)\n",
    "    writer.add_histogram('histogram/b2_mu', net.l2.bias_mu,epoch)\n",
    "    writer.add_histogram('histogram/b2_rho', net.l2.bias_rho,epoch)\n",
    "    writer.add_histogram('histogram/b3_mu', net.l3.bias_mu,epoch)\n",
    "    writer.add_histogram('histogram/b3_rho', net.l3.bias_rho,epoch)\n",
    "\n",
    "def write_loss_scalars(epoch, batch_idx, loss, log_prior, log_variational_posterior, negative_log_likelihood):\n",
    "    writer.add_scalar('logs/loss', loss, epoch*NUM_BATCHES+batch_idx)\n",
    "    writer.add_scalar('logs/complexity_cost', log_variational_posterior-log_prior, epoch*NUM_BATCHES+batch_idx)\n",
    "    writer.add_scalar('logs/log_prior', log_prior, epoch*NUM_BATCHES+batch_idx)\n",
    "    writer.add_scalar('logs/log_variational_posterior', log_variational_posterior, epoch*NUM_BATCHES+batch_idx)\n",
    "    writer.add_scalar('logs/negative_log_likelihood', negative_log_likelihood, epoch*NUM_BATCHES+batch_idx)\n",
    "\n",
    "def train(net, optimizer, epcoh):\n",
    "    net.train()\n",
    "    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):\n",
    "        net.zero_grad()\n",
    "        loss, log_prior, log_variational_posterior, mse_loss = net.sample_elbo(data, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        write_loss_scalars(epoch, batch_idx, loss, log_prior, log_variational_posterior, mse_loss)\n",
    "    write_weight_histograms(epoch+1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/600 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "size mismatch, m1: [2800 x 28], m2: [1 x 20] at /pytorch/aten/src/TH/generic/THTensorMath.c:2033",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-203-cec21c699ccf>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTRAIN_EPOCHS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-202-cc82eb972664>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(net, optimizer, epcoh)\u001b[0m\n\u001b[1;32m     26\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m         \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_prior\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_variational_posterior\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmse_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_elbo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     29\u001b[0m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-201-54754f74736c>\u001b[0m in \u001b[0;36msample_elbo\u001b[0;34m(self, input, target, batch_size, output_dim, samples)\u001b[0m\n\u001b[1;32m     30\u001b[0m         \u001b[0mlog_variational_posteriors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     31\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m             \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     33\u001b[0m             \u001b[0mlog_priors\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_prior\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m             \u001b[0mlog_variational_posteriors\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_variational_posterior\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    489\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    490\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 491\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    492\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    493\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-201-54754f74736c>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, sample)\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0ml1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0ml2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0ml3\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    489\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    490\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 491\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    492\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    493\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-194-33d8f213ff05>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, sample, calculate_log_probs)\u001b[0m\n\u001b[1;32m     35\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_prior\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_variational_posterior\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m    992\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    993\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 994\u001b[0;31m     \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    995\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    996\u001b[0m         \u001b[0moutput\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: size mismatch, m1: [2800 x 28], m2: [1 x 20] at /pytorch/aten/src/TH/generic/THTensorMath.c:2033"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the recorded output below shows this cell failed with a\n",
    "# RuntimeError size mismatch (m1: [2800 x 28] vs m2: [1 x 20]) inside\n",
    "# F.linear. The data fed by `train` does not match `net`'s input layer —\n",
    "# presumably an MNIST loader was paired with the 1-D regression net;\n",
    "# verify which `net`/`train_loader` definitions were active. TODO confirm.\n",
    "optimizer = optim.Adam(net.parameters())\n",
    "for epoch in range(TRAIN_EPOCHS):  # TRAIN_EPOCHS defined in an earlier cell\n",
    "    train(net, optimizer, epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_ensemble(net):\n",
    "    \"\"\"Evaluate `net` as an ensemble of TEST_SAMPLES stochastic forward passes.\n",
    "\n",
    "    Prints the accuracy of each individual weight sample (\"component\"),\n",
    "    of the deterministic pass (`sample=False`), and of the ensemble mean\n",
    "    over all TEST_SAMPLES + 1 output tensors.\n",
    "\n",
    "    Relies on module-level globals: TEST_SAMPLES, CLASSES, TEST_SIZE,\n",
    "    test_loader.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    correct = 0\n",
    "    corrects = np.zeros(TEST_SAMPLES + 1, dtype=int)\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            # Size the buffer from the actual batch: the final batch may be\n",
    "            # smaller than TEST_BATCH_SIZE unless the loader set drop_last=True,\n",
    "            # in which case a fixed-size buffer would raise on assignment.\n",
    "            outputs = torch.zeros(TEST_SAMPLES + 1, data.size(0), CLASSES)\n",
    "            for i in range(TEST_SAMPLES):\n",
    "                outputs[i] = net(data, sample=True)\n",
    "            # Extra slot: one deterministic (non-sampled) forward pass.\n",
    "            outputs[TEST_SAMPLES] = net(data, sample=False)\n",
    "            output = outputs.mean(0)  # ensemble average over all passes\n",
    "            preds = outputs.max(2, keepdim=True)[1]  # argmax per component\n",
    "            pred = output.max(1, keepdim=True)[1]    # argmax of the ensemble\n",
    "            corrects += preds.eq(target.view_as(pred)).sum(dim=1).squeeze().numpy()\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "    for index, num in enumerate(corrects):\n",
    "        if index < TEST_SAMPLES:\n",
    "            print('Component {} Accuracy: {}/{}'.format(index, num, TEST_SIZE))\n",
    "        else:\n",
    "            print('Posterior Mean Accuracy: {}/{}'.format(num, TEST_SIZE))\n",
    "    print('Ensemble Accuracy: {}/{}'.format(correct, TEST_SIZE))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Component 0 Accuracy: 8551/10000\n",
      "Component 1 Accuracy: 8531/10000\n",
      "Component 2 Accuracy: 8549/10000\n",
      "Component 3 Accuracy: 8565/10000\n",
      "Component 4 Accuracy: 8554/10000\n",
      "Component 5 Accuracy: 8546/10000\n",
      "Component 6 Accuracy: 8550/10000\n",
      "Component 7 Accuracy: 8553/10000\n",
      "Component 8 Accuracy: 8536/10000\n",
      "Component 9 Accuracy: 8542/10000\n",
      "Posterior Mean Accuracy: 8567/10000\n",
      "Ensemble Accuracy: 8574/10000\n"
     ]
    }
   ],
   "source": [
    "# Run the evaluation defined above (prints per-component and ensemble accuracy).\n",
    "test_ensemble(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 1.1146,  0.4821, -0.5735],\n",
       "         [ 0.6798,  0.0434,  1.1977],\n",
       "         [ 0.7786,  1.6315,  0.9597]]), tensor([[ 2,  1,  2],\n",
       "         [ 2,  1,  1],\n",
       "         [ 1,  0,  2]]))"
      ]
     },
     "execution_count": 151,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch cell: Tensor.max along dim 0 returns a (values, indices) pair,\n",
    "# as shown in the output below.\n",
    "x.max(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on built-in function max:\n",
      "\n",
      "max(...) method of torch.Tensor instance\n",
      "    max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor)\n",
      "    \n",
      "    See :func:`torch.max`\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Scratch cell: docs for the bound Tensor.max method.\n",
    "help(x.max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on built-in function max:\n",
      "\n",
      "max(...)\n",
      "    .. function:: max(input) -> Tensor\n",
      "    \n",
      "    Returns the maximum value of all elements in the :attr:`input` tensor.\n",
      "    \n",
      "    Args:\n",
      "        input (Tensor): the input tensor\n",
      "    \n",
      "    Example::\n",
      "    \n",
      "        >>> a = torch.randn(1, 3)\n",
      "        >>> a\n",
      "        tensor([[ 0.6763,  0.7445, -2.2369]])\n",
      "        >>> torch.max(a)\n",
      "        tensor(0.7445)\n",
      "    \n",
      "    .. function:: max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)\n",
      "    \n",
      "    Returns the maximum value of each row of the :attr:`input` tensor in the given\n",
      "    dimension :attr:`dim`. The second return value is the index location of each\n",
      "    maximum value found (argmax).\n",
      "    \n",
      "    If :attr:`keepdim` is ``True``, the output tensors are of the same size\n",
      "    as :attr:`input` except in the dimension :attr:`dim` where they are of size 1.\n",
      "    Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting\n",
      "    in the output tensors having 1 fewer dimension than :attr:`input`.\n",
      "    \n",
      "    Args:\n",
      "        input (Tensor): the input tensor\n",
      "        dim (int): the dimension to reduce\n",
      "        keepdim (bool): whether the output tensors have :attr:`dim` retained or not\n",
      "        out (tuple, optional): the result tuple of two output tensors (max, max_indices)\n",
      "    \n",
      "    Example::\n",
      "    \n",
      "        >>> a = torch.randn(4, 4)\n",
      "        >>> a\n",
      "        tensor([[-1.2360, -0.2942, -0.1222,  0.8475],\n",
      "                [ 1.1949, -1.1127, -2.2379, -0.6702],\n",
      "                [ 1.5717, -0.9207,  0.1297, -1.8768],\n",
      "                [-0.6172,  1.0036, -0.6060, -0.2432]])\n",
      "        >>> torch.max(a, 1)\n",
      "        (tensor([ 0.8475,  1.1949,  1.5717,  1.0036]), tensor([ 3,  0,  0,  1]))\n",
      "    \n",
      "    .. function:: max(input, other, out=None) -> Tensor\n",
      "    \n",
      "    Each element of the tensor :attr:`input` is compared with the corresponding\n",
      "    element of the tensor :attr:`other` and an element-wise maximum is taken.\n",
      "    \n",
      "    The shapes of :attr:`input` and :attr:`other` don't need to match,\n",
      "    but they must be :ref:`broadcastable <broadcasting-semantics>`.\n",
      "    \n",
      "    .. math::\n",
      "        out_i = \\max(tensor_i, other_i)\n",
      "    \n",
      "    .. note:: When the shapes do not match, the shape of the returned output tensor\n",
      "              follows the :ref:`broadcasting rules <broadcasting-semantics>`.\n",
      "    \n",
      "    Args:\n",
      "        input (Tensor): the input tensor\n",
      "        other (Tensor): the second input tensor\n",
      "        out (Tensor, optional): the output tensor\n",
      "    \n",
      "    Example::\n",
      "    \n",
      "        >>> a = torch.randn(4)\n",
      "        >>> a\n",
      "        tensor([ 0.2942, -0.7416,  0.2653, -0.1584])\n",
      "        >>> b = torch.randn(4)\n",
      "        >>> b\n",
      "        tensor([ 0.8722, -1.7421, -0.4141, -0.5055])\n",
      "        >>> torch.max(a, b)\n",
      "        tensor([ 0.8722, -0.7416,  0.2653, -0.1584])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Scratch cell: docs for the free function torch.max (all three overloads).\n",
    "help(torch.max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
